Deep Learning in Scientific Inverse Problems¶

Matteo Ravasi, Assistant Professor - KAUST¶

Data Assimilation Summer School 2023, 26/27 July 2023

Notation¶

Inverse Problems

  • $\mathbf{m}$: model. Quantity that we are interested in knowing and that we think affects the observed data - human organs, rock properties, pressure, atmospheric pressure, ...
  • $\mathbf{d}$: observed data. Quantity that we can physically measure - CT scan, seismic data, production data, precipitation, ...
  • $\mathbf{G}$ (or $g()$): modelling operator. Set of equations that we think can explain the data as a linear/nonlinear combination of the model parameters - Radon transform, seismic convolutional model, Navier-Stokes (or any other PDE), ...

Deep Learning

  • $\mathbf{d}^{(i)}, \mathbf{m}^{(i)} \; i=1,..., N_{train}$: training data. Pairs of data and model that are linked via the modelling operator $\mathbf{G}$. Can be directly acquired with an expensive experiment (usually $\mathbf{m}$ cannot be acquired!) or numerically generated. Compactly represented as $\mathbf{D}$ and $\mathbf{M}$ with training samples over columns.
  • $f_\theta$: parametric function (e.g., Neural network), where $\theta$ are the parameters to be learned.

Inverse problems¶

$$ \underset{\mathbf{m}} {\mathrm{argmin}} \; \mathcal{L}(\mathbf{d}, \mathbf{Gm}) + \mathcal{R}(\mathbf{m}) $$

where:

  • $\mathcal{L}$: loss (e.g., $||.||_2^2$)
  • $\mathcal{R}$: regularization

Deep Learning¶

$$ \underset{\theta} {\mathrm{argmin}} \; \mathcal{L}(\mathbf{M}, \mathbf{D}) + \mathcal{R}(\theta) = \underset{\theta} {\mathrm{argmin}} \; \frac{1}{N_{train}} \sum_{i=1}^{N_{train}}\mathcal{L}(\mathbf{m}^{(i)}, f_\theta(\mathbf{d}^{(i)})) + \mathcal{R}(\theta) $$

where:

  • $\mathcal{L}$: loss (e.g., $||.||_2^2$)
  • $\mathcal{R}$: regularization
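As a toy illustration of this objective, here is a minimal NumPy sketch (all sizes, the scalar "network" $f_\theta(d)=\theta d$, and the ridge regularizer are hypothetical choices, not part of the exercises) that minimizes the empirical risk by gradient descent:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical training pairs (m^(i), d^(i)) linked by m = 2 d (plus small noise)
N_train = 200
d = rng.normal(size=N_train)
m = 2.0 * d + 0.01 * rng.normal(size=N_train)

theta = 0.0            # parameter of the (scalar) "network" f_theta(d) = theta * d
alpha, lam = 0.1, 1e-3  # learning rate, regularization weight

for _ in range(500):
    residual = theta * d - m                                 # f_theta(d^(i)) - m^(i)
    grad = (2.0 / N_train) * d @ residual + 2 * lam * theta  # grad of loss + ridge
    theta -= alpha * grad
```

After training, `theta` is close to the true slope of 2 linking model and data.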

Goals¶

  • Short intro to Physical problem: CT Scan
  • Short intro to Deep Learning
    • Supervised learning as post-processing (EX3)
    • Deep Image Prior (EX4)
  • Short intro to Proximal Algorithms
    • Variational Inversion (EX2)
    • Plug-and-Play Prior (EX5)
    • Learned iterative solvers (EX6)

Our tools¶

CT Scan¶

A leading medical imaging method to visualize the inside of the human body (bones, organs, and soft tissues)

CT Scan: fastMRI dataset¶

Open dataset to ease development of Deep Learning methods for medical imaging: https://fastmri.med.nyu.edu

Although fastMRI is an MRI dataset, we will use its brain images as ground-truth models for CT scan imaging!

CT Scan: how does it work?¶

Forward

$$ \mathbf{d} = \mathbf{Gm} + \mathbf{n} $$

Inverse (i.e., imaging)

  • Backprojection: $\mathbf{m}_{bp} = \mathbf{G}^T \mathbf{d}$
  • Filtered Backprojection: $\mathbf{m}_{fbp} = \mathcal{F}(\mathbf{G}^T \mathbf{d})=\mathcal{F}_{fbp}(\mathbf{d})$
  • Least-squares inversion: $\mathbf{m}_{ls} = \underset{\mathbf{m}} {\mathrm{argmin}} \quad \frac{1}{2}||\mathbf{d}-\mathbf{Gm}||_2^2 + \mathcal{R}(\mathbf{m})$
  • Other specialized methods (ART, SART, SIRT...)
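To see the difference between backprojection and least-squares inversion on something tiny, here is a hedged NumPy sketch where a small dense matrix stands in for $\mathbf{G}$ (no real scanner geometry - sizes are purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical dense "projection" matrix standing in for G (24 rays, 16 pixels)
n = 16
G = rng.normal(size=(24, n)) / np.sqrt(n)
m_true = rng.normal(size=n)
d = G @ m_true

m_bp = G.T @ d                                # backprojection: adjoint only
m_ls = np.linalg.lstsq(G, d, rcond=None)[0]   # least-squares inversion

err_bp = np.linalg.norm(m_bp - m_true)        # large: BP scales/blurs the model
err_ls = np.linalg.norm(m_ls - m_true)        # tiny: LS recovers it exactly here
```

The adjoint alone only maps data back to model space ($\mathbf{G}^T\mathbf{G} \neq \mathbf{I}$); inverting the normal equations removes that blur.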

CT Scan: ASTRA Toolbox¶

Open-source Python (and MATLAB) library for high-performance (CPU and GPU) 2D and 3D tomography.

Supports 2D parallel and fan beam geometries, and 3D parallel and cone beam.

Link to library: https://github.com/astra-toolbox/astra-toolbox

In [3]:
# Load brain image from HF file
f1 = h5py.File('../data/file_brain_AXFLAIR_200_6002471.h5', 'r')
brain = f1['reconstruction_rss'][10]
brain /= brain.max()
nx, ny = brain.shape
In [4]:
# Acquisition geometry parameters
dxdet = 2
nxdet = nx // dxdet
detectors = np.arange(nxdet) * dxdet

nangles = 41
angles = np.linspace(0, np.pi, nangles, False)

# Volume and projection objects
vol_geom = astra.create_vol_geom(nx, ny)
proj_geom = astra.create_proj_geom('parallel', dxdet, nxdet, angles)
proj_id = astra.create_projector('strip', proj_geom, vol_geom)

# Create a sinogram
sinogram_id, sinogram = astra.create_sino(brain, proj_id)
In [5]:
with plt.xkcd():
    fig, axs = plt.subplots(1, 2, figsize=(15, 6))
    axs[0].imshow(brain, vmin=0, vmax=1, cmap='bone')
    axs[0].set_xlabel('X-Location'), axs[0].set_ylabel('Y-Location')
    axs[0].set_title('True Image')
    axs[0].axis('tight')
    axs[1].imshow(sinogram.T, cmap='gray', vmin=0, vmax=200, 
                  extent=(np.rad2deg(angles[0]), np.rad2deg(angles[-1]), 
                          detectors[-1], detectors[0]))
    axs[1].set_xlabel('Angles'), axs[1].set_ylabel('Detector location')
    axs[1].set_title('Sinogram')
    axs[1].axis('tight');
In [6]:
# Create a data object for the reconstruction
rec_id = astra.data2d.create('-vol', vol_geom)

# Define algorithm
cfg = astra.astra_dict('FBP')
cfg['ReconstructionDataId'] = rec_id
cfg['ProjectionDataId'] = sinogram_id
cfg['ProjectorId'] = proj_id
alg_id = astra.algorithm.create(cfg)

# Run algorithm
astra.algorithm.run(alg_id, 1)
brainfbp = astra.data2d.get(rec_id)
In [7]:
# CT linear operator
Cop = CT2D((nx, ny), dxdet, nxdet, angles)

# Back-projection
brainback = Cop.H @ sinogram

# LS Inverse
D2op = pylops.Laplacian(dims=(nx, ny), edge=True, dtype=np.float64)

braininv = pylops.optimization.leastsquares.regularized_inversion(
    Cop, sinogram.ravel(), [D2op], epsRs=[3e0], **dict(iter_lim=50)
)[0]
braininv = braininv.reshape(nx, ny)
In [9]:
with plt.xkcd():
    plot_models([brain, brainback, brainfbp, braininv], 
                ['True Image', 'BP Image', f'FBP Image SNR={pylops.utils.metrics.snr(brain, brainfbp):.2f}',
                 f'LS Image SNR={pylops.utils.metrics.snr(brain, braininv):.2f}'],
                [[0, 1], [500, 2000], [0, 1], [0, 1]])
In [10]:
with plt.xkcd():
    plot_data([sinogram, Cop @ brainfbp, Cop @ braininv], 
                ['Sinogram', 'FBP Sinogram', 'LS Sinogram'],
                [[0, 200], [0, 200], [0, 200]])
In [11]:
with plt.xkcd():
    plot_data([sinogram, sinogram-Cop @ brainfbp, sinogram-Cop @ braininv], 
              ['Sinogram', 'FBP Residual', 'LS Residual'],
              [[0, 200], [-50, 50], [-50, 50]])

Neural network: Perceptron¶

Neural network: Multi-Layer Perceptron¶

Neural network: Deep Multi-Layer Perceptron¶

Neural network: Training¶

Gradient descent

$$ \boldsymbol\theta_{k+1} = \boldsymbol\theta_k - \alpha_k \nabla \mathcal{L} = \boldsymbol\theta_k - \frac{\alpha_k}{N_{train}} \sum_{i=1}^{N_{train}} \nabla \mathcal{L}^{(i)} $$

Mini-batch gradient descent

$$ \boldsymbol\theta_{k+1} = \boldsymbol\theta_k - \alpha_k \nabla \mathcal{L} = \boldsymbol\theta_k - \frac{\alpha_k}{N_{batch}} \sum_{i=1}^{N_{batch}} \nabla \mathcal{L}^{(i)} $$

where:

  • $\alpha_k$: learning rate
  • $N_{batch}$: number of samples in a batch (i.e., a random partition of the training data, $N_{batch} \ll N_{train}$)
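A minimal NumPy sketch of mini-batch gradient descent on a hypothetical least-squares problem (batch size, learning rate, and sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear problem fitted with mini-batch gradient descent
N_train, n = 512, 8
A = rng.normal(size=(N_train, n))
theta_true = rng.normal(size=n)
y = A @ theta_true

theta = np.zeros(n)
alpha, N_batch = 0.01, 32
for epoch in range(200):
    idx = rng.permutation(N_train)              # reshuffle every epoch
    for start in range(0, N_train, N_batch):
        b = idx[start:start + N_batch]          # random partition of the data
        grad = A[b].T @ (A[b] @ theta - y[b]) / N_batch
        theta -= alpha * grad

err = np.abs(theta - theta_true).max()          # converges to theta_true
```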

Neural network: Back-propagation¶

Neural network: Back-propagation (derivation)¶

Let's consider the following 3-layer feed-forward NN:

$$ \textbf{z}^{[1]} = \textbf{W}^{[1]}\textbf{x} + \textbf{b}^{[1]}, \quad \textbf{a}^{[1]} = \sigma(\textbf{z}^{[1]}), $$

$$ \textbf{z}^{[2]} = \textbf{W}^{[2]}\textbf{a}^{[1]} + \textbf{b}^{[2]}, \quad \textbf{a}^{[2]} = \sigma(\textbf{z}^{[2]}), $$

$$ z^{[3]} = \textbf{w}^{[3]T}\textbf{a}^{[2]} + b^{[3]}, \quad a^{[3]} = \sigma(z^{[3]}), $$

followed by a generic loss $\mathcal{L}(y, a^{[3]})$.

We would like to compute: $\partial \mathcal{L} / \partial \textbf{W}^{[2]}$ (under the assumption of sigmoid activation and binary cross-entropy loss).

Neural network: Back-propagation (derivation)¶

Backpropagation = chain rule (implemented from right to left)

$$ \frac{\partial \mathcal{L}}{\partial \mathbf{W}^{[2]}} = \frac{\partial \textbf{z}^{[2]}}{\partial \textbf{W}^{[2]}} \frac{\partial \mathbf{a}^{[2]}}{\partial \mathbf{z}^{[2]}} \frac{\partial z^{[3]}}{\partial \mathbf{a}^{[2]}} \frac{\partial a^{[3]}}{\partial z^{[3]}} \frac{\partial \mathcal{L}}{\partial a^{[3]}} \\ $$

For each term in the above equation (apart from the leftmost one, $\partial \textbf{z}^{[2]}/\partial \textbf{W}^{[2]}$), we simply need to be able to apply the transposed Jacobian to a vector ($\mathbf{J}^T \mathbf{v}$).

Remember:

$$ \mathbf{y} = f(\mathbf{x}) \rightarrow \mathbf{J} = \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \frac{\partial y_1}{\partial x_2} & ... & \frac{\partial y_1}{\partial x_{N_x}} \\ ... & ... & ... & ... \\ \frac{\partial y_{N_y}}{\partial x_1} & \frac{\partial y_{N_y}}{\partial x_2} & ... & \frac{\partial y_{N_y}}{\partial x_{N_x}} \\ \end{bmatrix} \in \mathbb{R}^{[{N_y} \times {N_x}]} $$

Neural network: Back-propagation (derivation)¶

Expanding each term, we get:

$$ \frac{\partial \mathcal{L}}{\partial \mathbf{W}^{[2]}} = \begin{bmatrix} \mathbf{a}^{[1]} & \mathbf{0} & \ldots & \mathbf{0} \\ \mathbf{0} & \mathbf{a}^{[1]} & \ldots & \mathbf{0} \\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{0} & \mathbf{0} & \ldots & \mathbf{a}^{[1]} \end{bmatrix} diag\{\mathbf{a}^{[2]}(1-\mathbf{a}^{[2]})\} \textbf{w}^{[3]} (a^{[3]} - y) $$

where $\frac{\partial a^{[3]}}{\partial z^{[3]}} \frac{\partial \mathcal{L}}{\partial a^{[3]}}=(a^{[3]} - y)$
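The expression above can be checked numerically. A NumPy sketch that builds the backpropagated gradient $\delta^{[2]} \mathbf{a}^{[1]T}$ with $\delta^{[2]} = \mathbf{a}^{[2]}(1-\mathbf{a}^{[2]}) \odot \mathbf{w}^{[3]} (a^{[3]}-y)$ and compares it against central finite differences (layer sizes are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))

# Hypothetical layer sizes for the 3-layer network above
n0, n1, n2 = 4, 5, 3
W1, b1 = rng.normal(size=(n1, n0)), rng.normal(size=n1)
W2, b2 = rng.normal(size=(n2, n1)), rng.normal(size=n2)
w3, b3 = rng.normal(size=n2), rng.normal()
x, y = rng.normal(size=n0), 1.0

def forward(W2):
    a1 = sigmoid(W1 @ x + b1)
    a2 = sigmoid(W2 @ a1 + b2)
    a3 = sigmoid(w3 @ a2 + b3)
    loss = -(y * np.log(a3) + (1 - y) * np.log(1 - a3))   # binary cross-entropy
    return a1, a2, a3, loss

a1, a2, a3, _ = forward(W2)

# Backpropagated gradient: outer(delta2, a1), delta2 = a2(1-a2) * w3 * (a3 - y)
delta2 = a2 * (1 - a2) * w3 * (a3 - y)
grad = np.outer(delta2, a1)

# Central finite-difference check of dL/dW2
eps, fd = 1e-6, np.zeros_like(W2)
for i in range(n2):
    for j in range(n1):
        Wp, Wm = W2.copy(), W2.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        fd[i, j] = (forward(Wp)[3] - forward(Wm)[3]) / (2 * eps)
```

The two gradients agree to finite-difference accuracy.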

Supervised learning¶

Goal: train a network to map FBP solutions into True solutions

Given a set of brain images $M=\{\mathbf{m}^{(1)},\mathbf{m}^{(2)},..., \mathbf{m}^{(N_{train})}\}$

  • Reconstruct FBP images: $\mathbf{d}^{(i)} = \mathcal{F}_{fbp}(\mathbf{G} \mathbf{m}^{(i)})$
  • Train network: $$ \underset{\theta} {\mathrm{argmin}} \; \mathcal{L}(\mathbf{M}, \mathbf{D}) + \mathcal{R}(\theta) = \underset{\theta} {\mathrm{argmin}} \; \frac{1}{N_{train}} \sum_{i=1}^{N_{train}}\mathcal{L}(\mathbf{m}^{(i)}, f_\theta(\mathbf{d}^{(i)})) + \mathcal{R}(\theta) $$
  • Apply network to unseen data $\mathbf{d}$: $\hat{\mathbf{m}} = f_\theta({\mathcal{F}_{fbp}(\mathbf{d})})$

Supervised learning - network architecture¶

Many architectures:

  • Fully connected (MLP)
  • Convolution
  • Recurrent
  • Transformer
  • Invertible (Coupling Flow)
  • ...

Convolution is the natural choice for gridded variables (e.g., images, volumes)

Supervised learning - network architecture¶

Supervised learning¶

Time to practice: EX3

Deep Image Prior¶

Regularization vs Preconditioning

  • Regularization: $$\underset{\mathbf{m}} {\mathrm{argmin}} \; \mathcal{L}(\mathbf{d}, \mathbf{Gm}) + \mathcal{R}(\mathbf{m}) \qquad (e.g., \underset{\mathbf{m}} {\mathrm{argmin}} \; \frac{1}{2} ||\mathbf{d}-\mathbf{Gm}||_p^p + \lambda ||\mathbf{Rm}||_p^p) $$

  • Preconditioning: $$\underset{\mathbf{z}} {\mathrm{argmin}} \; \mathcal{L}(\mathbf{d}, \mathbf{G}\mathcal{P}(\mathbf{z})) \qquad (e.g., \underset{\mathbf{z}} {\mathrm{argmin}} \;\frac{1}{2} ||\mathbf{d}-\mathbf{GPz}||_p^p)$$

Deep Image Prior¶

Deep Image Prior = replace $\mathcal{P}$ with an untrained network $f_\theta$ and $\mathbf{z}$ with a random noise realization

$$ \underset{\theta} {\mathrm{argmin}} \; \mathcal{L}(\mathbf{d}, \mathbf{G}f_\theta(\mathbf{z})) $$

Rationale: exploit the inductive bias of the network architecture (together with the optimizer and its hyperparameters)
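A minimal PyTorch sketch of the Deep Image Prior loop, with a hypothetical toy operator $\mathbf{G}$ (causal integration) and a tiny MLP standing in for a real physical operator and architecture:

```python
import numpy as np
import torch

torch.manual_seed(0)

# Hypothetical toy operator G (causal integration) and a blocky model
n = 32
G = torch.from_numpy(np.tril(np.ones((n, n), dtype=np.float32)) / n)
m_true = torch.zeros(n)
m_true[10:20] = 1.0
d = G @ m_true

# Untrained network f_theta and a fixed random input z
f = torch.nn.Sequential(torch.nn.Linear(n, 64), torch.nn.ReLU(),
                        torch.nn.Linear(64, n))
z = torch.randn(n)

loss0 = torch.mean((d - G @ f(z)) ** 2).item()   # data misfit before fitting

opt = torch.optim.Adam(f.parameters(), lr=1e-2)
for _ in range(2000):
    opt.zero_grad()
    loss = torch.mean((d - G @ f(z)) ** 2)       # L(d, G f_theta(z))
    loss.backward()
    opt.step()

final_loss = torch.mean((d - G @ f(z)) ** 2).item()
m_dip = f(z).detach()                            # DIP estimate of the model
```

Only the network weights are optimized; $\mathbf{z}$ stays fixed, and the data misfit drops by orders of magnitude.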

Deep Image Prior: Physical (linear) operators in NNs¶

Challenge: how to implement $\mathbf{G}f_\theta(\mathbf{z})$ and the associated backward?

Let's consider the following chain of operations:

$$ \mathbf{y} = (f_3 \circ f_2 \circ f_1)(\mathbf{x}) $$

where $f_3(\mathbf{x}) = \mathbf{G} \mathbf{x}$

Two routes to incorporate $f_3$ in PyTorch Autograd:

  • Implement a native PyTorch operator $\rightarrow$ very time consuming (if you already have lots of code)
  • Simply tell PyTorch what the associated Jacobian is.

Deep Image Prior: Physical (linear) operators in NNs¶

Simply tell PyTorch what the associated Jacobian is.

Given:

$$ \mathbf{y} = \mathbf{Gx} = \begin{bmatrix} G_{1,1} & G_{1,2} \\ G_{2,1} & G_{2,2} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} G_{1,1} x_1 + G_{1,2} x_2 \\ G_{2,1} x_1 + G_{2,2} x_2 \end{bmatrix} $$

the Jacobian is:

$$ \mathbf{J} = \begin{bmatrix} \frac{\partial y_1}{\partial x_1}=G_{1,1} & \frac{\partial y_1}{\partial x_2}=G_{1,2} \\ \frac{\partial y_2}{\partial x_1}=G_{2,1} & \frac{\partial y_2}{\partial x_2}=G_{2,2} \end{bmatrix} = \mathbf{G} $$

so $\mathbf{J}^T \mathbf{v}=\mathbf{G}^T \mathbf{v}$ (we need to be able to apply the adjoint of the operator to a vector).
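This is exactly what a custom `torch.autograd.Function` lets us do: supply the forward ($\mathbf{Gx}$) and the action of the transposed Jacobian ($\mathbf{G}^T\mathbf{v}$) explicitly. A minimal sketch with a hypothetical NumPy matrix:

```python
import numpy as np
import torch

# Wrap a NumPy matrix G into the autograd graph by supplying the forward (G x)
# and the action of the transposed Jacobian (G^T v) explicitly
class MatVec(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, G):
        ctx.G = G
        return torch.from_numpy(G @ x.detach().numpy())

    @staticmethod
    def backward(ctx, v):
        # J^T v = G^T v; None is the (non-existent) gradient w.r.t. G itself
        return torch.from_numpy(ctx.G.T @ v.detach().numpy()), None

rng = np.random.default_rng(0)
G = rng.normal(size=(3, 2))
x = torch.ones(2, dtype=torch.float64, requires_grad=True)
y = MatVec.apply(x, G)

v = torch.ones(3, dtype=torch.float64)
y.backward(v)   # x.grad now holds G^T v
```

`pylops.TorchOperator`, used in the next cell, automates precisely this wrapping for any PyLops linear operator.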

In [12]:
# Inserting a PyLops linear operator into Pytorch computational graph

# Create PyLops derivative operator
n = 32
Dop = pylops.FirstDerivative(dims=(n, n))
Dop_torch = pylops.TorchOperator(Dop, device='cpu')

# Forward
x = torch.ones((n, n), requires_grad=True)
y = Dop_torch.apply(x.view(-1))

# Backward
v = torch.ones(n*n)
y.backward(v, retain_graph=True)
jtv = x.grad

# Verify equivalence with Dop.H @ v
np.allclose(jtv.numpy(), Dop.H @ np.ones((n,n)))
Out[12]:
True

Deep Image Prior¶

Time to practice: EX4

Proximal algorithms¶

Gradient-based optimization

$$ \underset{\mathbf{m}} {\mathrm{argmin}} \quad \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 + \alpha ||\mathbf{m}||_2^2 + \beta ||\nabla_x\mathbf{m}||_2^2 + \gamma ||\nabla_z\mathbf{m}||_2^2 \qquad \textrm{L2-regularized (linear)} $$

and

$$ \underset{\mathbf{m}} {\mathrm{argmin}} \quad \frac{1}{2} ||\mathbf{d} - g(\mathbf{m})||_2^2 + \alpha ||\mathbf{m}||_2^2 + \beta ||\nabla_x\mathbf{m}||_2^2 + \gamma ||\nabla_z\mathbf{m}||_2^2 \qquad \textrm{L2-regularized (nonlinear)} $$

Proximal algorithms¶

Proximal-based optimization

$$ \underset{\mathbf{m}} {\mathrm{argmin}} \quad \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 + \alpha ||\mathbf{m}||_1 \qquad \textrm{L1-regularized / Sparse inversion} $$

and

$$ \underset{\mathbf{m}} {\mathrm{argmin}} \quad \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 + \alpha ||\mathbf{M}||_* \qquad \textrm{Low-rank matrix approximation} $$

and

$$ \underset{\mathbf{m}} {\mathrm{argmin}} \quad \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 + \alpha TV(\mathbf{m}) \qquad \textrm{TV-regularized inversion} $$

and

$$ \underset{\mathbf{m} \in C} {\mathrm{argmin}} \quad \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 \qquad \textrm{Constrained inversion} $$

and many more...

In [13]:
# Proximal gradient CT reconstruction
sigma = 0.01
tv = pyproximal.TV((nx, ny), sigma=sigma, niter=20)
l2 = pyproximal.L2(Cop, b=sinogram.ravel(), niter=20)

L = np.real((Cop.H*Cop).eigs(neigs=1, which='LM')[0])
tau = 0.99 / L

brainpg = pyproximal.optimization.primal.ProximalGradient(
    l2, tv, x0=np.zeros(nx*ny), tau=tau, niter=50, show=True)
brainpg = brainpg.reshape(nx, ny)
Accelerated Proximal Gradient
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.proximal.TV.TV'>
tau = 0.00015775425340130387	beta=5.000000e-01
epsg = 1.0	niter = 50
niterback = 100	acceleration = None

   Itn       x[0]          f           g       J=f+eps*g
     1   8.35400e-02   9.442e+05   1.342e+00   9.442e+05
     2   7.61705e-02   5.435e+05   1.874e+00   5.435e+05
     3   6.98595e-02   3.248e+05   2.376e+00   3.248e+05
     4   6.51900e-02   2.023e+05   2.845e+00   2.023e+05
     5   6.18387e-02   1.323e+05   3.274e+00   1.323e+05
     6   5.95123e-02   9.118e+04   3.670e+00   9.118e+04
     7   5.79664e-02   6.636e+04   4.041e+00   6.636e+04
     8   5.70049e-02   5.085e+04   4.393e+00   5.085e+04
     9   5.64740e-02   4.075e+04   4.731e+00   4.075e+04
    10   5.62555e-02   3.388e+04   5.058e+00   3.388e+04
    11   5.62594e-02   2.898e+04   5.375e+00   2.898e+04
    16   5.77901e-02   1.667e+04   6.844e+00   1.668e+04
    21   5.97702e-02   1.112e+04   8.148e+00   1.113e+04
    26   6.14119e-02   7.823e+03   9.305e+00   7.832e+03
    31   6.26840e-02   5.669e+03   1.033e+01   5.679e+03
    36   6.36714e-02   4.196e+03   1.123e+01   4.207e+03
    41   6.44512e-02   3.157e+03   1.203e+01   3.169e+03
    42   6.45874e-02   2.987e+03   1.218e+01   2.999e+03
    43   6.47181e-02   2.828e+03   1.233e+01   2.841e+03
    44   6.48433e-02   2.679e+03   1.247e+01   2.691e+03
    45   6.49635e-02   2.539e+03   1.261e+01   2.551e+03
    46   6.50790e-02   2.407e+03   1.274e+01   2.420e+03
    47   6.51899e-02   2.283e+03   1.288e+01   2.296e+03
    48   6.52966e-02   2.167e+03   1.300e+01   2.180e+03
    49   6.53992e-02   2.057e+03   1.313e+01   2.070e+03
    50   6.54980e-02   1.953e+03   1.325e+01   1.967e+03

Total time (s) = 5.38
---------------------------------------------------------

In [14]:
# ADMM CT reconstruction
sigma=0.01
l1 = pyproximal.L21(ndim=2, sigma=sigma)
Dop = pylops.Gradient(dims=(nx, ny), edge=True, dtype=Cop.dtype, kind='forward')

L = 8. #np.real((Dop.H*Dop).eigs(neigs=1, which='LM')[0])
tau = 0.99 / L

brainadmm = pyproximal.optimization.primal.ADMML2(
    l1, Cop, sinogram.ravel(), Dop, x0=np.zeros(nx*ny), tau=tau, 
    niter=20, show=True, **dict(iter_lim=20))[0]
brainadmm = brainadmm.reshape(nx, ny)
ADMM
---------------------------------------------------------
Proximal operator (g): <class 'pyproximal.proximal.L21.L21'>
tau = 1.237500e-01	niter = 20

   Itn       x[0]          f           g       J = f + g
     1   7.02514e-02   3.424e+00   1.499e+01   1.841e+01
     2   6.92162e-02   7.121e-01   1.490e+01   1.561e+01
     3   6.86522e-02   3.840e-01   1.473e+01   1.511e+01
     4   6.85469e-02   3.173e-01   1.451e+01   1.483e+01
     5   6.85017e-02   3.107e-01   1.430e+01   1.461e+01
     6   6.85599e-02   2.834e-01   1.411e+01   1.439e+01
     7   6.86347e-02   2.843e-01   1.393e+01   1.422e+01
     8   6.86523e-02   2.651e-01   1.377e+01   1.404e+01
     9   6.87126e-02   2.715e-01   1.362e+01   1.389e+01
    10   6.87267e-02   2.566e-01   1.348e+01   1.374e+01
    11   6.86942e-02   2.638e-01   1.335e+01   1.361e+01
    12   6.85921e-02   2.549e-01   1.322e+01   1.348e+01
    13   6.84441e-02   2.594e-01   1.310e+01   1.336e+01
    14   6.82822e-02   2.514e-01   1.299e+01   1.324e+01
    15   6.81547e-02   2.554e-01   1.289e+01   1.314e+01
    16   6.80010e-02   2.489e-01   1.279e+01   1.304e+01
    17   6.78624e-02   2.538e-01   1.269e+01   1.295e+01
    18   6.76322e-02   2.483e-01   1.260e+01   1.285e+01
    19   6.75801e-02   2.510e-01   1.252e+01   1.277e+01
    20   6.72734e-02   2.478e-01   1.244e+01   1.269e+01

Total time (s) = 29.59
---------------------------------------------------------

In [15]:
with plt.xkcd():
    plot_models([brain, braininv, brainpg, brainadmm], 
                ['True Image', f'LS Image SNR={pylops.utils.metrics.snr(brain, braininv):.2f}',
                 f'PG Image SNR={pylops.utils.metrics.snr(brain, brainpg):.2f}', 
                 f'ADMM Image SNR={pylops.utils.metrics.snr(brain, brainadmm):.2f}'],
                [[0, 1], [0, 1], [0, 1], [0, 1]])

Proximal operator¶

$$ prox_{\tau f} (\mathbf{v}) = (I+ \tau \partial f)^{-1}(\mathbf{v}) = \underset{\mathbf{m}} {\mathrm{argmin}} \quad f(\mathbf{m}) + \frac{1}{2\tau}||\mathbf{m}-\mathbf{v}||_2^2 $$

Source: Parikh N., Proximal Algorithms.

Proximal operator¶

An inverse problem at every step of the iterative scheme... bahhh!!

Luckily many proximal operators have closed-form solution:

$$ \text{Squared L2:} \; prox_{\tau ||\cdot||_2^2} (\mathbf{v}) = \underset{\mathbf{m}} {\mathrm{argmin}} \quad ||\mathbf{m}||_2^2 + \frac{1}{2\tau}||\mathbf{m}-\mathbf{v}||_2^2 = \frac{\mathbf{v}}{1+2\tau} $$

$$ \text{L1:} \; prox_{\tau ||\cdot||_1} (\mathbf{v}) = \underset{\mathbf{m}} {\mathrm{argmin}} \quad ||\mathbf{m}||_1 + \frac{1}{2\tau}||\mathbf{m}-\mathbf{v}||_2^2 = soft(\mathbf{v}, \tau) \quad \text{(soft-thresholding)} $$

$$ \text{Box constraint:} \; prox_{Box_{[l, u]}} (\mathbf{v}) = \underset{\mathbf{m} \in Box_{[l,u]}} {\mathrm{argmin}} \quad \frac{1}{2\tau}||\mathbf{m}-\mathbf{v}||_2^2 = \min\{ \max \{\mathbf{v}, l\}, u \} $$

See Parikh N., Proximal Algorithms for an extensive list of Proximal operators.
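These closed-form solutions are easy to verify numerically. A NumPy sketch (values hypothetical) that also checks the first two proximal operators against a brute-force grid minimization of $f(m) + \frac{1}{2\tau}(m-v)^2$:

```python
import numpy as np

rng = np.random.default_rng(0)
v = np.clip(rng.normal(size=5), -2, 2)   # hypothetical input vector
tau, l, u = 0.3, -0.5, 0.5

prox_l2sq = v / (1 + 2 * tau)                            # squared L2
prox_l1 = np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)  # soft-thresholding
prox_box = np.minimum(np.maximum(v, l), u)               # box projection

# Brute-force check: each prox minimizes f(m) + (m - v)^2 / (2 tau)
grid = np.linspace(-3, 3, 20001)
ok = True
for f, p in [(lambda m: m ** 2, prox_l2sq),
             (lambda m: np.abs(m), prox_l1)]:
    for k in range(5):
        obj = f(grid) + (grid - v[k]) ** 2 / (2 * tau)
        ok &= abs(grid[np.argmin(obj)] - p[k]) < 1e-3
```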

Optimizing mixed (smooth & non-smooth) functionals¶

Let's consider the case

$$ \underset{\mathbf{m}} {\mathrm{argmin}} \; \mathcal{L}(\mathbf{d}, \mathbf{Gm}) + \mathcal{R}(\mathbf{m}) $$

where $\mathcal{L}$ is a smooth function and $\mathcal{R}$ is a non-smooth, convex function;

A smart way of solving this general problem is to use the Forward-Backward splitting algorithm:

$$ 0 \in \nabla \mathcal{L}(\mathbf{m}) + \partial \mathcal{R}(\mathbf{m})\\ 0 \in \mathbf{m}/\alpha - \mathbf{m}/\alpha + \nabla \mathcal{L}(\mathbf{m}) + \partial \mathcal{R}(\mathbf{m})\\ (I-\alpha \nabla \mathcal{L})(\mathbf{m}) \in (I+\alpha \partial \mathcal{R})(\mathbf{m})\\ \mathbf{m}_* = (I+\alpha \partial \mathcal{R})^{-1} (I-\alpha \nabla \mathcal{L})(\mathbf{m}_*)\\ \mathbf{m}_* = prox_{\alpha \mathcal{R}}(I-\alpha \nabla \mathcal{L})(\mathbf{m}_*)\\ $$

Explicit form for fixed-point iterations: $$ \mathbf{m}_{k+1} = prox_{\alpha \mathcal{R}}(I-\alpha \nabla \mathcal{L})(\mathbf{m}_k)\\ $$

Optimizing mixed (smooth & non-smooth) functionals¶

You may recognize the famous ISTA algorithm here!

When $\mathcal{L}=\frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2$ and $\mathcal{R}=||\mathbf{m}||_1$, we have:

$$ \mathbf{m}_{k+1} = prox_{\alpha ||\cdot||_1} (\mathbf{m}_k - \alpha\mathbf{G}^H(\mathbf{G}\mathbf{m}_k-\mathbf{d}))\\ $$
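A self-contained NumPy sketch of ISTA on a hypothetical sparse-recovery problem (sizes, support, and $\lambda$ are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sparse-recovery problem: d = G m with a 3-sparse model
n, p = 40, 50
G = rng.normal(size=(n, p)) / np.sqrt(n)
m_true = np.zeros(p)
m_true[[5, 20, 40]] = [2.0, -1.5, 1.0]
d = G @ m_true

soft = lambda v, t: np.sign(v) * np.maximum(np.abs(v) - t, 0.0)
alpha = 0.9 / np.linalg.norm(G, 2) ** 2     # step size below 1/L
lam = 0.01                                  # sparsity weight

m = np.zeros(p)
for _ in range(3000):
    m = soft(m - alpha * G.T @ (G @ m - d), alpha * lam)  # ISTA step

rel_err = np.linalg.norm(m - m_true) / np.linalg.norm(m_true)
```

Even though the system is underdetermined ($p > n$), the L1 penalty recovers the sparse model.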

Optimizing mixed (smooth & non-smooth) functionals¶

When both $f$ and $g$ are non-smooth, convex functions, the proximal gradient algorithm (or ISTA) cannot be used.

In order to solve this problem, a three-step procedure is required:

  • Splitting: $\mathbf{y}=\mathbf{m}$, such that
$$ \underset{\mathbf{m}, \mathbf{y}} {\mathrm{argmin}} \quad f(\mathbf{m}) + g(\mathbf{y}) $$
  • Augmented Lagrangian: $\text{arg} \underset{\mathbf{m},\mathbf{y}} {\mathrm{min}} \underset{\mathbf{z}} {\mathrm{max}} \mathcal{L}$, where $$ \mathcal{L}=f(\mathbf{m}) + g(\mathbf{y}) + \frac{\rho}{2}||\mathbf{m}-\mathbf{y}||_2^2 + \mathbf{z}^T(\mathbf{m}-\mathbf{y}) $$
  • Alternating minimization:
$$ \mathbf{m}_k = \underset{\mathbf{m}} {\mathrm{argmin}} \quad \mathcal{L}(\mathbf{m}, \mathbf{y}_{k-1}, \mathbf{z}_{k-1}) = prox_{f/\rho}(\mathbf{y}_{k-1}-\mathbf{z}_{k-1}) \\ \mathbf{y}_k = \underset{\mathbf{y}} {\mathrm{argmin}} \quad \mathcal{L}(\mathbf{m}_k, \mathbf{y}, \mathbf{z}_{k-1}) = prox_{g/\rho}(\mathbf{m}_k+\mathbf{z}_{k-1}) \\ \mathbf{z}_k = \underset{\mathbf{z}} {\mathrm{argmax}} \quad \mathcal{L}(\mathbf{m}_k, \mathbf{y}_k, \mathbf{z}) = \mathbf{z}_{k-1} + \mathbf{m}_k - \mathbf{y}_k $$

This is the famous ADMM algorithm.

Proximal operator - another interpretation¶

$$ prox_{\tau f} (\mathbf{v}) = (I+ \tau \partial f)^{-1}(\mathbf{v}) = \underset{\mathbf{m}} {\mathrm{argmin}} \; f(\mathbf{m}) + \frac{1}{2\tau}||\mathbf{m}-\mathbf{v}||_2^2 $$

Denoising inverse problem with data $\mathbf{v}$ and generic regularizer $f$.

Plug-and-Play (PnP) priors¶

Replace the proximal operator with a denoiser of choice (e.g., powerful NN-based denoisers)

PnP proximal gradient:

$$ \mathbf{m}_{k+1} = Denoiser(\mathbf{m}_k - \alpha\mathbf{G}^H(\mathbf{G}\mathbf{m}_k-\mathbf{d}))\\ $$
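A hedged NumPy sketch of PnP proximal gradient, where a simple moving-average smoother stands in for a trained NN denoiser (operator and signal are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical deblurring problem: G is a (row-normalized) Gaussian blur
n = 64
t = np.arange(n)
blur = np.exp(-0.5 * ((t[:, None] - t[None, :]) / 2.0) ** 2)
G = blur / blur.sum(axis=1, keepdims=True)
m_true = np.clip(np.sin(2 * np.pi * t / n), 0.0, None)   # smooth signal
d = G @ m_true + 0.01 * rng.normal(size=n)

def denoiser(v, k=3):
    # moving-average smoother standing in for a trained NN denoiser
    kernel = np.ones(2 * k + 1) / (2 * k + 1)
    return np.convolve(np.pad(v, k, mode='edge'), kernel, 'valid')

alpha = 0.9 / np.linalg.norm(G, 2) ** 2
m = np.zeros(n)
for _ in range(200):
    m = denoiser(m - alpha * G.T @ (G @ m - d))   # PnP proximal gradient step

rel_res = np.linalg.norm(G @ m - d) / np.linalg.norm(d)
```

Any denoiser can be plugged in here; the regularizer is whatever prior the denoiser implicitly encodes.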

Plug-and-Play (PnP) priors¶

Time to practice: EX5

Learned iterative solvers¶

Gradient descent

$$ \mathbf{m}_{k+1} = f(\mathbf{m}_k, \partial (\mathcal{L} + \mathcal{R}) / \partial \mathbf{m}; \alpha_{k})= \mathbf{m}_k - \alpha_k \frac{\partial (\mathcal{L} + \mathcal{R})}{\partial \mathbf{m}} $$

Learned iterative solvers¶

Learned Gradient descent $$ \mathbf{m}_{k+1} = f(\mathbf{m}_k, \partial (\mathcal{L} + \mathcal{R}) / \partial \mathbf{m}; \theta)= \mathbf{m}_k - f_\theta\left(\frac{\partial ( \mathcal{L} + \mathcal{R})}{\partial \mathbf{m}} \right) $$

Training: $$ \underset{\theta} {\mathrm{argmin}} \; \frac{1}{N_{train}} \sum_{i=1}^{N_{train}} \sum_{k=1}^{N_{it}}\mathcal{L}(\mathbf{m}^{(i)}, F_\theta(\mathbf{d}^{(i)}, \mathbf{m}_0^{(i)})|_k) + \mathcal{R}(\theta) $$
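A minimal PyTorch sketch of a learned (unrolled) solver where $f_\theta$ is reduced to one learnable step size per iteration, trained end-to-end through the unrolled gradient steps (everything here is a hypothetical toy, far simpler than the networks used in EX6):

```python
import numpy as np
import torch

torch.manual_seed(0)

# Toy learned solver: f_theta is just one learnable step size per iteration
n, p, K = 30, 20, 8
G = torch.randn(n, p) / np.sqrt(n)

theta = torch.full((K,), 0.1, requires_grad=True)

def unroll(d):
    m = torch.zeros(p)
    for k in range(K):
        m = m - theta[k] * (G.T @ (G @ m - d))   # learned gradient step
    return m

with torch.no_grad():                            # evaluate before training
    m_eval = torch.randn(p)
    d_eval = G @ m_eval
    err_before = torch.mean((unroll(d_eval) - m_eval) ** 2).item()

opt = torch.optim.Adam([theta], lr=1e-2)
for _ in range(500):
    m_train = torch.randn(p)                     # fresh training sample
    d_train = G @ m_train
    opt.zero_grad()
    loss = torch.mean((unroll(d_train) - m_train) ** 2)
    loss.backward()
    opt.step()

with torch.no_grad():                            # evaluate after training
    err_after = torch.mean((unroll(d_eval) - m_eval) ** 2).item()
```

Training tunes the per-iteration steps so that a fixed, small budget of iterations reconstructs far better than naive fixed-step gradient descent.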

Learned iterative solvers¶

Proximal gradient

$$ \mathbf{m}_{k+1} = f(\mathbf{m}_k, \partial \mathcal{L} / \partial \mathbf{m}; \alpha_{k})= prox_{\alpha_k \mathcal{R}} \left( \mathbf{m}_k - \alpha_k \frac{\partial \mathcal{L}}{\partial \mathbf{m}} \right) $$

Learned Proximal Gradient $$ \mathbf{m}_{k+1} = f_\theta \left( \mathbf{m}_k - \alpha_k \frac{\partial \mathcal{L}}{\partial \mathbf{m}} \right) $$

Learned iterative solvers¶

Other options inspired by classical optimizers (but not directly linked to a specific one).

  • Adler and Oktem, 2017
$$ \mathbf{m}_{k+1} = \mathbf{m}_k - f_\theta \left(\mathbf{m}_k \oplus \frac{\partial \mathcal{L}}{\partial \mathbf{m}} \oplus \frac{\partial \mathcal{R}}{\partial \mathbf{m}} \right) $$
  • Adler and Oktem, 2017 (with memory $\mathbf{s}_0 = \mathbf{0}$)
$$ \Delta \mathbf{m}_{k+1} \oplus \mathbf{s}_{k+1} = f_\theta \left(\mathbf{m}_k \oplus \mathbf{s}_k \oplus \frac{\partial \mathcal{L}}{\partial \mathbf{m}} \oplus \frac{\partial \mathcal{R}}{\partial \mathbf{m}} \right) $$

$$ \mathbf{m}_{k+1} = \mathbf{m}_k - \Delta \mathbf{m}_{k+1} $$

where $\oplus$ denotes concatenation.

Learned iterative solvers¶

Time to practice: EX6

Summary¶

Classical inversion algorithms:

  • Gradient-based: $$\underset{\mathbf{m}} {\mathrm{argmin}} \; \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 + \mathcal{R}(\mathbf{m}) \qquad (\mathcal{R} \; \text{differentiable}) $$

  • Proximal-based: $$\underset{\mathbf{m}} {\mathrm{argmin}} \; \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 + \mathcal{R}(\mathbf{m}) \qquad (\mathcal{R} \; \text{proximable}) $$

Summary¶

Zoo of deep learning-based inversion algorithms:

  • Supervised learning: $$ \underset{\theta} {\mathrm{argmin}} \; \frac{1}{N_{train}} \sum_{i=1}^{N_{train}}\mathcal{L}(\mathbf{m}^{(i)}, f_\theta(\mathbf{d}^{(i)})) + \mathcal{R}(\theta) $$
  • Deep Image prior (totally unsupervised): $$ \underset{\theta} {\mathrm{argmin}} \; \frac{1}{2} ||\mathbf{d} - \mathbf{G}f_\theta(\mathbf{z})||_2^2 $$
  • PnP prior (partially supervised): $$ \underset{\mathbf{m}} {\mathrm{argmin}} \; \frac{1}{2} ||\mathbf{d} - \mathbf{Gm}||_2^2 + \mathcal{R}(\mathbf{m}) \quad (\mathcal{R} \; \text{associated to a denoising alg.}) $$
  • Learned solvers (supervised): $$ \underset{\theta} {\mathrm{argmin}} \; \frac{1}{N_{train}} \sum_{i=1}^{N_{train}} \sum_{k=1}^{N_{it}}\mathcal{L}(\mathbf{m}^{(i)}, F_\theta(\mathbf{d}^{(i)}, \mathbf{m}_0^{(i)})|_k) + \mathcal{R}(\theta) $$

Links¶

PyLops:

  • Github: https://github.com/PyLops/pylops
  • Doc: http://pylops.readthedocs.io
  • Tutorial: https://github.com/PyLops/pylops_pydata2020

PyProximal:

  • Github: https://github.com/PyLops/pyproximal
  • Doc: http://pyproximal.readthedocs.io
  • Tutorial: https://github.com/PyLops/pylops_notebooks/tree/master/official/iicuseminar_2022

ASTRA Toolbox:

  • Github: https://github.com/astra-toolbox/astra-toolbox
  • Doc: https://www.astra-toolbox.com
  • Tutorial: https://visielab.uantwerpen.be/astra-training